# 非极大值抑制剔除置信度小的框

import tensorflow as tf
import numpy as np
from anchors import anchors_prediction  # 导入对网络每一输出层的解码方法

# 对所有输出层解码，并使用非极大值抑制
def predict(feats, image_shape, num_classes, pre_anchors, conf_thresh, nms_thresh, max_boxes=20):
    '''
    feats:代表模型的三个有效输出特征层
    image_shape:原始图像的(h,w)
    max_boxes:最大预测框数量
    conf_thresh:分类概率小于这个值的框被删除
    nms_thresh:两个框计算交并比, iou小于这个值被保留, 删除重复的框
    '''
    # 每个特征层对应的三个先验框
    anchor_mask = [[6,7,8], [3,4,5], [0,1,2]]

    boxes = []  # 存放每张图的预测框坐标
    box_scores = []  # 每个预测框的类别概率

    # 获得网络输入的图像的shape=416*416
    inputs_shape = tf.constant([416,416])
    
    # 获取三个有效输出特征层的预测框坐标和类别概率
    for i in range(len(feats)):
        #获取每个输出层的预测框坐标和概率
        feat_boxes, feat_box_scores = anchors_prediction(feats=feats[i], 
                                                         anchors=pre_anchors[anchor_mask[i]],  # 每个输出层对应的3个先验框
                                                         num_classes=num_classes,  # 分类数
                                                         inputs_shape=inputs_shape,  # 网络的输出图像大小[416,416]
                                                         image_shape=image_shape)  # 原始图像大小
        # 保存每一层的预测框坐标和概率
        # feat_boxes.shape=[2028,4], feat_box_scores.shape=[2028,20]
        boxes.append(feat_boxes)
        box_scores.append(feat_box_scores)

    # 调整排序方式，放在一行中 [[], [], []]==>[ , , ]
    boxes = tf.concat(boxes, axis=0)
    box_scores = tf.concat(box_scores, axis=0)

    # 设置一个阈值，保留概率高于该阈值的预测框
    # mask.shape=[3*2028, 20]
    mask = (box_scores) > conf_thresh
    # 设置每张图片最多出现几个预测框
    max_boxes = tf.constant(max_boxes, tf.int32)

    '''
    box_score是每个检测框对应的20个类别的概率
    mask是每个检测框的20个类别的概率是否满足要求
    '''

    # 保存每个预测框的坐标, 概率, 类别名称
    predict_boxes = []
    predict_score = []
    predict_classes = []

    # 对每一个类别删除重复的框
    for c in range(num_classes):
        # 取出所有类为c且满足最小概率要求的box
        class_boxes = tf.boolean_mask(boxes, mask[:, c])
        # 取出所有类为c的分数
        class_box_scores = tf.boolean_mask(box_scores[:, c], mask[:, c])
        # 非极大抑制
        nms_index = tf.image.non_max_suppression(class_boxes, class_box_scores, max_boxes, iou_threshold = nms_thresh)

        # 取出筛选后的结果
        class_boxes = tf.gather(class_boxes, nms_index)  # 检测框坐标
        class_box_scores = tf.gather(class_box_scores, nms_index)  # 每个框的类别概率
        classes = tf.ones_like(class_box_scores, dtype=tf.int32) * c  # 1*c获得类别索引

        # 保存筛选后的预测框坐标、概率、类别
        predict_boxes.append(class_boxes)
        predict_score.append(class_box_scores)
        predict_classes.append(classes)

    # 重新排序预测框, 放在同一行
    predict_boxes = tf.concat(predict_boxes, axis=0)
    predict_score = tf.concat(predict_score, axis=0)
    predict_classes = tf.concat(predict_classes, axis=0)

    return predict_boxes, predict_score, predict_classes


# 验证
if __name__ == '__main__':

    # 构造网络的三个输出特征层
    feat1 = tf.random.normal([4,13,13,75], mean=0, stddev=0.5) 
    feat2 = tf.random.normal([4,13,13,75], mean=0, stddev=0.5)  
    feat3 = tf.random.normal([4,13,13,75], mean=0, stddev=0.5) 
    feats = [feat1, feat2, feat3]
    
    # 先验框宽高
    anchors = np.array([[12, 16],  [19, 36],  [40, 28],  [36, 75],  [76, 55],  [72, 146],  [142, 110],  [192, 243],  [459, 401]])
    # 获得调整后的预测框
    predict_boxes, predict_score, predict_classes = predict(feats, image_shape=[1280,720], num_classes=20, pre_anchors=anchors, conf_thresh=0.5, nms_thresh=0.4, max_boxes=20)

    print(predict_boxes, predict_score, predict_classes)